// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

package promising

import (
	"context"
	"fmt"
	"sync"
	"sync/atomic"

	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/trace"
)

// promise represents a result that will become available at some point
// in the future, delivered by an asynchronous [Task].
//
// Callers interact with a promise through the [PromiseResolver] and
// [PromiseGet] values returned by [NewPromise]; this type holds the state
// shared between them.
type promise struct {
	// name is a human-readable label used in the promise's trace span name
	// and returned by [PromiseID.FriendlyName].
	name string

	// responsible is the task currently obliged to resolve the promise. It is
	// set to nil when resolvePromise records a result, but deliberately left
	// in place by resolvePromiseInternalFailure. result stays nil until the
	// promise is resolved, either by the responsible task or by an internal
	// failure. traceSpan is the promise's own root span, created in NewPromise.
	responsible atomic.Pointer[task]
	result      atomic.Pointer[promiseResult]
	traceSpan   trace.Span

	// waiting holds one channel per getter call that is blocked on this
	// promise; each channel is closed once a result is recorded. waitingMu
	// guards waiting and serializes resolution against the getter's
	// check-then-wait sequence.
	waiting   []chan<- struct{}
	waitingMu sync.Mutex
}

// promiseID returns the opaque, comparable identifier for p that is exposed
// to callers as a [PromiseID].
func (p *promise) promiseID() PromiseID {
	id := PromiseID{promise: p}
	return id
}

// promiseResult is the final outcome of a promise, published through the
// promise's atomic result pointer.
type promiseResult struct {
	// val and err are what the promise getter returns. val is left nil when
	// the result was recorded without a value, as internal failures do.
	val any
	err error

	// forced is set when this result was generated by the promise machinery
	// itself, as opposed to from calling tasks. We use this to behave more
	// gracefully when the responsible task resolution races with the internal
	// error, so that we can treat that differently to when the responsible
	// task itself tries to resolve a promise multiple times.
	forced bool
}

// getResolvedPromiseResult converts a stored promise result into the typed
// values returned by a [PromiseGet] function.
//
// When the stored value is not a T — most notably when it is nil because the
// responsible task exited without resolving the promise — the zero value of
// T is returned together with the stored error.
func getResolvedPromiseResult[T any](result *promiseResult) (T, error) {
	var typed T
	if v, ok := result.val.(T); ok {
		typed = v
	}
	return typed, result.err
}

// PromiseID is an opaque, comparable unique identifier for a promise, which
// can therefore be used by callers to produce a lookup table of metadata for
// each active promise they are interested in.
//
// The identifier for a promise follows it as the responsibility to resolve it
// transfers between tasks.
//
// For example, this can be useful for retaining contextual information that
// can help explain which work was implicated in a dependency cycle between
// tasks.
type PromiseID struct {
	// promise is the identified promise; it is nil for [NoPromise].
	promise *promise
}

// FriendlyName returns the human-readable name that was given to the promise
// when it was created by [NewPromise].
//
// It must not be called on [NoPromise], which has no underlying promise.
func (id PromiseID) FriendlyName() string {
	underlying := id.promise
	return underlying.name
}

// NoPromise is the zero value of [PromiseID] and used to represent the absence
// of a promise.
var NoPromise PromiseID

// NewPromise creates a new promise that the calling task is initially
// responsible for and returns both its resolver and its getter.
//
// The given context must be a task context or this function will panic.
//
// The caller should retain the resolver for its own use and pass the getter
// to any other tasks that will consume the result of the promise.
//
// The getter blocks until the promise is resolved. It returns early with the
// context's error if the context passed to it is cancelled or reaches its
// deadline first. If waiting would make the calling task depend on itself,
// every promise in the dependency chain is immediately failed with an
// [ErrSelfDependent] error instead of blocking.
func NewPromise[T any](ctx context.Context, name string) (PromiseResolver[T], PromiseGet[T]) {
	callerSpan := trace.SpanFromContext(ctx)
	initialResponsible := mustTaskFromContext(ctx)
	p := &promise{name: name}
	p.responsible.Store(initialResponsible)
	initialResponsible.responsible[p] = struct{}{}

	// Each promise gets its own root span, linked to the creating task's
	// span, so its lifetime can be traced independently of whichever task
	// ends up holding responsibility for it.
	ctx, span := tracer.Start(
		ctx, fmt.Sprintf("promise(%s)", name),
		trace.WithNewRoot(),
		trace.WithLinks(trace.Link{
			SpanContext: trace.SpanContextFromContext(ctx),
		}),
	)
	_ = ctx // prevent staticcheck from complaining until we have something actually using this
	p.traceSpan = span
	promiseSpanContext := span.SpanContext()

	callerSpan.AddEvent("new promise", trace.WithAttributes(
		attribute.String("promising.responsible_for", promiseSpanContext.SpanID().String()),
	))

	resolver := PromiseResolver[T]{p}
	getter := PromiseGet[T](func(ctx context.Context) (T, error) {
		reqT := mustTaskFromContext(ctx)

		waiterSpan := trace.SpanFromContext(ctx)

		// Record what this task is waiting for, so that self-dependency
		// checks made by other getters can follow the chain through it.
		ok := reqT.awaiting.CompareAndSwap(nil, p)
		if !ok {
			// If we get here then the task seems to have forked into two
			// goroutines that are trying to await promises concurrently,
			// which is illegal per the contract for tasks.
			panic("racing promise get")
		}
		defer func() {
			ok := reqT.awaiting.CompareAndSwap(p, nil)
			if !ok {
				panic("racing promise get")
			}
		}()

		// We'll first test whether waiting for this promise is possible
		// without creating a deadlock, by following the awaiting->responsible
		// chain.
		checkP := p
		checkT := p.responsible.Load()
		steps := 1
		for checkT != reqT {
			steps++
			if checkT == nil {
				break
			}
			nextCheckP := checkT.awaiting.Load()
			if nextCheckP == nil {
				break
			}
			if checkP.responsible.Load() != checkT {
				break
			}
			checkP = nextCheckP
			checkT = checkP.responsible.Load()
		}
		if checkT == reqT {
			// We've found a self-dependency, but to report it in a useful
			// way we need to collect up all of the promises, so we'll
			// repeat the above and collect up all of the promises we find
			// along the way this time, instead of just counting them.
			err := make(ErrSelfDependent, 0, steps)
			var affectedPromises []*promise
			checkP := p
			checkT := p.responsible.Load()
			err = append(err, checkP.promiseID())
			affectedPromises = append(affectedPromises, checkP)
			for checkT != reqT {
				if checkT == nil {
					break
				}
				nextCheckP := checkT.awaiting.Load()
				if nextCheckP == nil {
					break
				}
				if checkP.responsible.Load() != checkT {
					break
				}
				checkP = nextCheckP
				checkT = checkP.responsible.Load()
				err = append(err, checkP.promiseID())
				affectedPromises = append(affectedPromises, checkP)
			}
			waiterSpan.AddEvent(
				"task is self-dependent",
				trace.WithAttributes(
					attribute.String("promise.waiting_for_id", promiseSpanContext.SpanID().String()),
				),
			)

			// All waiters for this promise need to see this error, because
			// otherwise the other waiters might stall forever waiting for
			// a result that will never come.
			for _, affected := range affectedPromises {
				resolvePromiseInternalFailure(affected, err)
			}
			// The current promise is one of the "affected promises" that
			// were resolved above, so we can now fall through to the check
			// below for whether the promise is already resolved and have
			// it return the error.
		}

		// If we get here then it's safe to actually await.
		p.waitingMu.Lock()
		if result := p.result.Load(); result != nil {
			// No need to wait because the result is already available.
			p.waitingMu.Unlock()
			waiterSpan.AddEvent(
				"promise is already resolved",
				trace.WithAttributes(
					attribute.String("promise.waiting_for_id", promiseSpanContext.SpanID().String()),
				),
			)
			return getResolvedPromiseResult[T](result)
		}

		ch := make(chan struct{})
		p.waiting = append(p.waiting, ch)
		waiterCount := len(p.waiting)
		p.waitingMu.Unlock()

		waiterSpan.AddEvent(
			"waiting for promise result",
			trace.WithAttributes(
				attribute.String("promise.waiting_for_id", promiseSpanContext.SpanID().String()),
				attribute.Int("promise.waiter_count", waiterCount),
			),
		)
		p.traceSpan.AddEvent(
			"new task waiting",
			trace.WithAttributes(
				attribute.String("promise.waiter_id", waiterSpan.SpanContext().SpanID().String()),
				attribute.Int("promise.waiter_count", waiterCount),
			),
		)

		select {
		case <-ch:
			// The channel is closed once the promise is resolved.
		case <-ctx.Done():
			// The promise may have been resolved at the same moment the
			// context ended, in which case the real result is preferred.
			if result := p.result.Load(); result != nil {
				return getResolvedPromiseResult[T](result)
			}
			// ch stays registered in p.waiting. It is closed harmlessly when
			// the promise is eventually resolved.
			waiterSpan.AddEvent(
				"stopped waiting for promise",
				trace.WithAttributes(
					attribute.String("promise.waiting_for_id", promiseSpanContext.SpanID().String()),
					attribute.String("error", ctx.Err().Error()),
				),
			)
			var zero T
			return zero, ctx.Err()
		}

		waiterSpan.AddEvent(
			"promise resolved",
			trace.WithAttributes(
				attribute.String("promise.waiting_for_id", promiseSpanContext.SpanID().String()),
			),
		)
		result := p.result.Load()
		if result == nil {
			// Both resolvers store the result before closing the waiter
			// channels, so reaching this point indicates a bug in one of them.
			panic("promise signaled resolved but has no result")
		}
		return getResolvedPromiseResult[T](result)
	})

	return resolver, getter
}

// resolvePromise records the result of promise p on behalf of its
// responsible task and wakes every task waiting for it.
//
// It clears the promise's responsibility and removes p from the responsible
// task's set of promises it must resolve.
//
// If an internal failure has already forced an error result onto p (see
// resolvePromiseInternalFailure), the given result is discarded without any
// error, because all waiters have already received the forced error. Any
// other repeated resolution indicates a bug in the calling task and panics.
func resolvePromise(p *promise, v any, err error) {
	p.waitingMu.Lock()
	defer p.waitingMu.Unlock()

	// Swap reads and clears the responsibility in a single atomic step.
	respT := p.responsible.Swap(nil)
	if respT == nil {
		// As far as this file shows, responsibility is cleared only here.
		// A nil task therefore means the promise is being resolved a second
		// time. Report that clearly rather than crashing with a nil pointer
		// dereference on respT.responsible.
		panic("promise resolved more than once")
	}
	respT.responsible.Remove(p)

	ok := p.result.CompareAndSwap(nil, &promiseResult{
		val: v,
		err: err,
	})
	if !ok {
		// The result that's now present might be a "forced error" generated
		// through resolvePromiseInternalFailure, in which case we just
		// quietly ignore the attempt to actually resolve it since all of the
		// waiters will already have received the error.
		r := p.result.Load()
		if r != nil && r.forced {
			return
		}
		// Any other conflict indicates a bug in the calling task.
		panic("promise resolved more than once")
	}

	for _, waitingCh := range p.waiting {
		close(waitingCh)
	}
	p.waiting = nil
}

// resolvePromiseInternalFailure is a variant of resolvePromise for errors
// that originate in the promise machinery itself rather than in the task
// responsible for the promise. For example, when tasks become self-dependent,
// every promise in the chain must fail immediately. Otherwise a waiter could
// be stuck forever waiting for a completion that might never come, or could
// see an incorrect result while the failure propagates through a different
// return path.
//
// The recorded result is marked as forced, so a later resolvePromise call
// from the responsible task is quietly ignored.
func resolvePromiseInternalFailure(p *promise, err error) {
	p.waitingMu.Lock()
	defer p.waitingMu.Unlock()

	p.traceSpan.AddEvent("internal promise failure", trace.WithAttributes(
		attribute.String("error", err.Error()),
	))

	// Responsibility is deliberately left untouched. The responsible task
	// usually never calls the getter, so it won't learn about this failure
	// and will still try to resolve the promise itself. Keeping the
	// responsibility in place lets that attempt complete as a no-op.
	forcedResult := &promiseResult{
		err:    err,
		forced: true,
	}
	if swapped := p.result.CompareAndSwap(nil, forcedResult); !swapped {
		// Either the responsible task resolved the promise first, or the
		// promise was caught up in another self-dependency whose error was
		// recorded first. The promise is resolved either way, so the
		// conflict is only noted to help with debugging.
		p.traceSpan.AddEvent("internal promise failure conflict")
	}

	// Wake every current waiter. After a conflict this list is normally
	// already empty, because the earlier resolution drained it.
	for _, ch := range p.waiting {
		close(ch)
	}
	p.waiting = nil
}

// PromiseGet is the signature of a promise "getter" function, which blocks
// until a promise is resolved and then returns its result values.
//
// A PromiseGet function may be called only within a task, using a context
// value that descends from that task's context.
//
// If the given context is cancelled or reaches its deadline then the function
// will return the relevant context-related error to describe that situation.
//
// If waiting for the promise would make the calling task depend on itself,
// the getter instead returns an [ErrSelfDependent] error listing the chain of
// promises involved. Whenever the promise was failed without a value, as in
// that case, the returned T is its zero value.
type PromiseGet[T any] func(ctx context.Context) (T, error)
